/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {test_util} from '@tensorflow/tfjs-core';
import {loadQnA, UniversalSentenceEncoderQnA} from './use_qna';

/**
 * Golden-value tests for the Universal Sentence Encoder QnA model.
 *
 * The model is loaded once for the whole suite via `loadQnA()`; embeddings
 * produced by `qna.embed` are compared element-wise against hard-coded
 * reference arrays with `test_util.expectArraysClose`.
 * NOTE(review): the reference values presumably come from a run of the
 * published model — regenerate them if the model weights change.
 */
describe('Universal Sentence Encoder QNA', () => {
  // Shared model instance, populated once before any spec runs.
  let qna: UniversalSentenceEncoderQnA;
  beforeAll(async () => {
    qna = await loadQnA();
  });

  it('basic usage', async () => {
    const result = await qna.embed({
      queries: ['what is the weather today?'],
      responses: ['today is cloudy.']
    });
    test_util.expectArraysClose(result['queryEmbedding'].dataSync(), [
      2.7645488,   0.30286145,  0.15519771,  3.3303149,   0.73709464,
      -0.24267645, 0.42912522,  0.6104698,   1.402183,    -2.319392,
      0.1386897,   0.20799994,  3.1938918,   0.48018247,  -0.2336823,
      3.1762602,   -0.15730904, 0.620998,    -0.3168982,  0.01887584,
      -0.03040927, 0.06796362,  0.39682847,  -0.8980976,  -0.23845802,
      0.15611805,  0.2725699,   -1.8144639,  0.5021897,   -0.13243707,
      -0.16235027, -0.8155376,  0.03923966,  0.12180302,  -0.7676933,
      3.2731714,   0.16963466,  0.33868262,  -0.20740426, -0.6665582,
      0.04915003,  -0.04177394, 0.37417072,  0.26179722,  0.82453203,
      0.13558507,  -0.02828247, 0.10979467,  0.4984138,   0.1251011,
      -0.40570533, 0.35685825,  0.4352448,   3.1945264,   0.19645983,
      -0.08614383, 0.6918269,   -0.04146867, 0.00547367,  0.31328753,
      -0.00504928, -0.9106121,  -0.09238561, 0.00692027,  0.3602638,
      -0.55486757, 0.58284456,  0.07299471,  0.24222873,  -1.7512876,
      -2.7701135,  0.16699581,  0.12134986,  0.4786817,   0.07161791,
      0.6970726,   -0.15573968, -0.8637686,  -0.8269252,  2.7347164,
      -0.28828332, 0.4646879,   -0.23542976, 0.1885213,   -0.07906074,
      -0.4910207,  0.08913047,  -0.09251469, -1.2985612,  -1.2836055,
      -0.13542742, 3.1167097,   -0.1678529,  -0.21905473, -1.3630934,
      0.2087331,   0.01345341,  0.34093782,  -0.02854175, 3.029753
    ]);
    test_util.expectArraysClose(result['responseEmbedding'].dataSync(), [
      -2.7903643,  -0.7020997,  -0.5393383,  0.7589595,   1.8675963,
      0.13674915,  -0.80103,    -0.34954107, 1.6652802,   1.7186067,
      -0.80345714, -0.18878363, -0.320481,   1.3214048,   0.9120344,
      1.707129,    -1.3292489,  -0.2988205,  0.19789799,  -1.4207102,
      1.1632164,   1.0613979,   -0.50598425, -0.743629,   0.45145196,
      -1.3591434,  -1.4505718,  0.12148909,  0.43211582,  -0.97200686,
      1.0016091,   -0.5041195,  1.1606078,   -0.07992788, -0.60910803,
      -0.07444222, -0.04325415, 0.8746788,   -1.3428974,  -0.9737125,
      1.4358486,   -0.67734474, 2.0742533,   0.6386704,   0.9713414,
      1.4017808,   -1.2397543,  1.6230887,   -0.36473927, -0.40323737,
      -0.9649109,  1.1477537,   1.349556,    1.073676,    1.2211813,
      -0.7674891,  1.7596902,   -0.67888576, -0.6901066,  1.1336823,
      1.0909919,   -0.03635395, -0.9884136,  -1.5454148,  -0.20521395,
      -1.0088263,  0.15151767,  -0.83107924, 1.7557824,   0.34723148,
      0.78867906,  0.9948809,   -0.9609835,  1.526174,    0.93398654,
      0.957069,    -1.3943961,  -1.0564212,  0.01647631,  2.260069,
      -1.8463705,  2.1627388,   -0.15955001, -1.8080655,  -0.63866234,
      0.21116942,  0.5413755,   0.63926923,  0.29422516,  -1.2522588,
      0.11680362,  0.09367324,  -1.5876546,  -1.6012754,  -0.26643804,
      0.608659,    2.2008464,   1.0552162,   -1.4435272,  -1.2838663
    ]);
    // Release the output tensors so they do not accumulate across specs.
    result['queryEmbedding'].dispose();
    result['responseEmbedding'].dispose();
  });

  it('with context', async () => {
    const result = await qna.embed({
      queries: ['what is the weather today?'],
      responses: ['today is cloudy.'],
      contexts: ['yesterday was sunny.']
    });
    // The expected query embedding matches the one in 'basic usage' to within
    // float precision; only the response embedding differs once a context is
    // supplied alongside the response.
    test_util.expectArraysClose(result['queryEmbedding'].dataSync(), [
      2.764549732208252,     0.3028636574745178,    0.15519654750823975,
      3.3303146362304688,    0.7370948791503906,    -0.24267707765102386,
      0.4291246831417084,    0.6104708909988403,    1.4021832942962646,
      -2.3193931579589844,   0.1386888325214386,    0.20800015330314636,
      3.193892002105713,     0.48018214106559753,   -0.23368141055107117,
      3.1762609481811523,    -0.15730996429920197,  0.62099689245224,
      -0.31689780950546265,  0.018875854089856148,  -0.030409209430217743,
      0.0679638683795929,    0.3968282639980316,    -0.8980970978736877,
      -0.2384580820798874,   0.1561172604560852,    0.2725693881511688,
      -1.8144654035568237,   0.5021904706954956,    -0.1324377954006195,
      -0.16235150396823883,  -0.8155357837677002,   0.039240460842847824,
      0.12180358916521072,   -0.7676945328712463,   3.2731711864471436,
      0.16963325440883636,   0.33868226408958435,   -0.20740365982055664,
      -0.6665595769882202,   0.049150239676237106,  -0.04177480190992355,
      0.3741709589958191,    0.2617974281311035,    0.824533998966217,
      0.1355849802494049,    -0.028282370418310165, 0.10979544371366501,
      0.4984150528907776,    0.12510059773921967,   -0.4057057201862335,
      0.35685938596725464,   0.43524476885795593,   3.194526433944702,
      0.19646026194095612,   -0.08614348620176315,  0.6918268799781799,
      -0.0414697602391243,   0.0054745920933783054, 0.313288152217865,
      -0.005049782805144787, -0.9106109142303467,   -0.09238557517528534,
      0.006920729763805866,  0.36026570200920105,   -0.5548665523529053,
      0.582845151424408,     0.07299471646547318,   0.2422289252281189,
      -1.7512867450714111,   -2.7701122760772705,   0.1669962853193283,
      0.12134931981563568,   0.4786826968193054,    0.07161817699670792,
      0.6970734000205994,    -0.1557406187057495,   -0.8637679219245911,
      -0.826924741268158,    2.7347168922424316,    -0.28828272223472595,
      0.46468833088874817,   -0.2354297637939453,   0.18852107226848602,
      -0.07906034588813782,  -0.49102115631103516,  0.08913090825080872,
      -0.09251506626605988,  -1.2985612154006958,   -1.2836053371429443,
      -0.13542881608009338,  3.1167097091674805,    -0.16785310208797455,
      -0.2190552055835724,   -1.3630945682525635,   0.2087334841489792,
      0.013453599996864796,  0.34093764424324036,   -0.028541648760437965,
      3.029752731323242
    ]);
    test_util.expectArraysClose(result['responseEmbedding'].dataSync(), [
      -2.747086524963379,   -1.5298696756362915,   -0.8569912910461426,
      -0.04533401131629944, 1.6793564558029175,    -0.08554036915302277,
      -1.0194844007492065,  0.47660741209983826,   1.5304087400436401,
      1.2770110368728638,   -0.9437439441680908,   -0.22166286408901215,
      -0.2539377212524414,  1.4826762676239014,    0.1694871485233307,
      1.4965760707855225,   -1.142553687095642,    -0.6277444958686829,
      0.17439526319503784,  -1.0646206140518188,   1.1420629024505615,
      1.1044495105743408,   -0.02009904757142067,  -1.1047433614730835,
      0.2740751802921295,   -0.9673320651054382,   -1.5674513578414917,
      -0.3471321761608124,  0.7423182725906372,    -0.45107758045196533,
      0.5625700354576111,   -0.4610292911529541,   1.4592750072479248,
      -0.10152439773082733, -0.7695645093917847,   0.7896363735198975,
      0.6434444785118103,   0.7013059258460999,    -1.2638391256332397,
      -1.3608455657958984,  1.665994644165039,     -0.6346866488456726,
      2.243438720703125,    1.2208253145217896,    1.0088047981262207,
      1.4481457471847534,   -1.2708708047866821,   1.8359601497650146,
      -0.2303360253572464,  0.2972204387187958,    -1.1097303628921509,
      1.205251693725586,    1.2754255533218384,    0.2633049786090851,
      1.5643799304962158,   -1.1438292264938354,   1.8493810892105103,
      -1.0438262224197388,  -0.5506157279014587,   1.2074977159500122,
      1.0494587421417236,   -0.1277640461921692,   -1.2264611721038818,
      -1.5335519313812256,  -0.29750311374664307,  -1.171984076499939,
      0.33427461981773376,  -0.34870827198028564,  2.031810760498047,
      0.10940468311309814,  0.7610474824905396,    0.9941542744636536,
      -1.0595879554748535,  1.4200491905212402,    0.7450745105743408,
      0.14273041486740112,  -1.4283344745635986,   -1.7344152927398682,
      -0.5243281722068787,  1.6276650428771973,    -1.8341976404190063,
      2.2950599193573,      -0.11798099428415298,  -1.627362847328186,
      -0.678684651851654,   0.24903929233551025,   -0.3194371163845062,
      1.2070250511169434,   0.15791155397891998,   -1.0031750202178955,
      0.24275082349777222,  -0.007341892924159765, -1.6534773111343384,
      -1.4808708429336548,  -0.33998817205429077,  0.8144207000732422,
      2.1109611988067627,   0.08290252834558487,   -1.0134201049804688,
      -0.10679633170366287
    ]);
    // Release the output tensors so they do not accumulate across specs.
    result['queryEmbedding'].dispose();
    result['responseEmbedding'].dispose();
  });

  it('fail when contexts and responses length do not match', () => {
    // Two contexts for a single response: `embed` is expected to throw
    // synchronously (the call is wrapped in a plain, non-async thunk).
    expect(() => qna.embed({
      queries: ['what is the weather today?'],
      responses: ['today is cloudy.'],
      contexts: ['test1.', 'test2']
    })).toThrowError();
  });
});
